
Generic implementation of Transition node #439

Merged
merged 9 commits into main from generic-transition on Jan 22, 2025
Conversation

wouterwln
Member

This adds a generic transition function for Categorical distributions, enabling models such as POMDPs. It allows model structures like this:

using RxInfer

@model function pomdp_demo(y)
    initial_state ~ Categorical([0.1, 0.4, 0.5]) # Prior on initial state
    control ~ Categorical([0.1, 0.4, 0.2, 0.3]) # Prior on a control variable
    A ~ TensorDirichlet(ones(3,3,4)) # Prior on transition function

    next_state ~ Transition(initial_state, A, control)
    y ~ Transition(next_state, diageye(3)) # Likelihood model is just identity
    
end

constraints = @constraints begin
    q(initial_state, next_state, control, A) = q(initial_state, next_state, control)q(A)
end

initialization = @initialization begin
    q(A) = TensorDirichlet(ones(3, 3, 4))
end

result = infer(model = pomdp_demo(),
               data = (y = [0, 1, 0],),
               constraints = constraints,
               iterations = 10,
               initialization = initialization,
               free_energy = true)

The interface is out ~ Transition(in, parameters, additional_interfaces...), so we can attach all kinds of additional interfaces to this Transition node. Say we have a categorical "context" state that should influence the transition function; we can add it easily with the same interface and node:

@model function pomdp_demo_with_context(y)
    initial_state ~ Categorical([0.1, 0.4, 0.5]) # Prior on initial state
    context ~ Categorical([0.2, 0.2, 0.2, 0.2, 0.2])
    control ~ Categorical([0.1, 0.4, 0.2, 0.3]) # Prior on a control variable

    A ~ TensorDirichlet(ones(3, 3, 5, 4) .+ 2) # Prior on transition function

    next_state ~ Transition(initial_state, A, context, control)
    y ~ Transition(next_state, diageye(3)) # Likelihood model is just identity
    
end

constraints = @constraints begin
    q(initial_state, next_state, control, context, A) = q(initial_state, next_state, control, context)q(A)
end

initialization = @initialization begin
    q(A) = TensorDirichlet(ones(3, 3, 5, 4) .+ 2)
end

result = infer(model = pomdp_demo_with_context(),
               data = (y = [0, 1, 0],),
               constraints = constraints,
               iterations = 10,
               initialization = initialization,
               free_energy = true)
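
For reference, a minimal sketch of inspecting the result of the call above, assuming the posteriors are stored per iteration (so the last entry corresponds to the final iteration):

result.posteriors[:next_state][end]  # marginal over next_state after the final iteration
result.posteriors[:A][end]           # marginal over the transition tensor A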

If anything is not clear, LMK.


codecov bot commented Jan 21, 2025

Codecov Report

Attention: Patch coverage is 87.50000% with 6 lines in your changes missing coverage. Please review.

Project coverage is 73.53%. Comparing base (1d28b19) to head (e15f559).
Report is 10 commits behind head on main.

Files with missing lines             Patch %   Lines
src/nodes/predefined/transition.jl   83.78%    6 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #439      +/-   ##
==========================================
+ Coverage   73.36%   73.53%   +0.17%     
==========================================
  Files         192      194       +2     
  Lines        5530     5582      +52     
==========================================
+ Hits         4057     4105      +48     
- Misses       1473     1477       +4     


@bvdmitri
Member

I'm not sure why the tests are failing; some Aqua-related stuff?

@bvdmitri bvdmitri merged commit 58a213a into main Jan 22, 2025
6 checks passed
@bvdmitri bvdmitri deleted the generic-transition branch January 22, 2025 14:25
@apashea

apashea commented Jan 25, 2025

Great work! We appreciate all you are doing!

Code works fine on my machine. So here's where I'm at in terms of a rough sketch of what I'm going for (if it's even possible):
I would essentially like to recreate, to a fair degree, the dynamics of the POMDP structure very, very commonly (I cannot emphasize that enough) found in Active Inference for purely discrete state-space models. I've been hoping for this since early last year, after significant time spent looking through the RxInfer examples (especially the HMM example, trying to figure out how to add a control and state transition model, as well as van de Laar et al.'s LAIF scripts) and the docs, un- and re-installing packages, etc.

I would very much enjoy putting out an RxInfer tutorial or two via livestreams with the Institute if we could nail this down (I see potential for expanding from there, e.g. hierarchical modeling and modifying the POMDP beyond the original scheme), since most of the requirements are already accomplished in your code examples. I think you would receive much more support from the ActInf community, and probably from our Institute learners (we collectively read and study the Friston et al. textbook annually, so newcomers to the Institute who program learn about POMDPs first), if you established a tutorial on building this key model architecture and process. The end goal, ultimately, is to understand RxInfer well enough in terms of our own priors that we can then surpass them and, say, build hierarchical, multimodal, and other kinds of interesting architectures made possible by RxInfer.

Primary requirements:

  • POMDP where the A, B, and D matrices (for the nomenclature, see the commented code below) can all be learned (parameterized, as is already done in your snippet for the likelihood model 'A')
  • Preferences (C) are part of the pragmatic value term in computing expected free energy
  • State and policy inference, i.e. Q(s) and Q(pi). Maybe this requires two @model's or some other scheme, I'm guessing, from what I remember of the LAIF script.

So a sort of rough sketch, where I have commented out desired but unintegrated components, is as follows:

using RxInfer;

@model function pomdp_demo(o)
    #A ~ TensorDirichlet(diageye(3) * 100 .+ 0.0001) # Likelihood model prior $P(o_t|s_t)$
    B ~ TensorDirichlet(ones(3,3,4)) # State transition model prior $P(s_{t+1}|s_t,pi)$
    #C ~ Categorical([0.0, 0.0, 1.0]) # Prior over observations (preferences which typically remain static) $P(o_0)$
    D ~ Categorical([0.1, 0.4, 0.5]) # Prior over initial states $P(s_0)$
    E ~ Categorical([0.1, 0.4, 0.2, 0.3]) # Prior on a control variable $P(pi_0)$
    
    q_s ~ Transition(D, B, E)
    #o ~ Transition(q_s, A)
    o ~ Transition(q_s, diageye(3))
    
end

constraints = @constraints begin
    q(D, q_s, E, B) = q(D, q_s, E)q(B)  #?
end
initialization = @initialization begin
    q(B) = TensorDirichlet(ones(3,3,4))
    #q(A) = TensorDirichlet(diageye(3) * 100 .+ 0.0001)
end

result = infer(model = pomdp_demo(),
               data = (o = [0, 1, 0],),
               constraints = constraints,
               iterations = 10,
               initialization = initialization,
               free_energy = true)

References to the general ActInf POMDP:

I'm happy to discuss this further, clarify, revise my expectations, delete this comment and discuss elsewhere, etc. etc.
Let me know!

Best,
Andrew

@wouterwln
Member Author

Hi @apashea, thanks for trying out the new functionality! Actually, the functionality you describe is already available, but it requires some additional knowledge (which I'm currently working on disentangling :) ). First of all, the Transition node without controls (so for the likelihood model) has been around for a long time already, and used the MatrixDirichlet distribution as a prior. MatrixDirichlet is kind of an artifact of previous RxInfer versions, and I want to fully deprecate it and replace it with TensorDirichlet, which is a strict generalization (see ReactiveBayes/ExponentialFamily.jl#227).

Now, as for the goal specification, it is important to notice that in RxInfer we should make an explicit distinction between current observations and future observations. On future observations (i.e. observations which will not be constrained by data), we can put priors, which become the goal priors. I'll try to demonstrate this with the short code example below, in which I also run likelihood matrix inference:

using RxInfer;

@model function pomdp_demo(o, u_current)
    # -------- Prior specification --------
    A ~ MatrixDirichlet(diageye(3) * 100 .+ 0.0001) # Likelihood model prior $P(o_t|s_t)$
    B ~ TensorDirichlet(ones(3,3,4)) # State transition model prior $P(s_{t+1}|s_t,pi)$
    s_0 ~ Categorical([0.1, 0.4, 0.5]) # Prior over initial states $P(s_0)$

    # -------- Model parameter inference --------

    # By supplying `u_current`, we can use the `Transition` node to learn the transition probabilities.
    s_current ~ Transition(s_0, B, u_current) # Transition model $P(s_{t+1}|s_t,pi)$
    # By supplying `o`, we can use the `Transition` node to learn the likelihood probabilities.
    o ~ Transition(s_current, A)

    # -------- Control variable inference (planning) --------
    u_next ~ Categorical([0.1, 0.4, 0.2, 0.3]) # Prior on a control variable $P(pi_0)$
    
    s_next ~ Transition(s_current, B, u_next) # Predict next state
    o_next ~ Transition(s_next, A) # Predict next observation

    o_next ~ Categorical([0.0, 0.0, 1.0]) # Prior over future observations
    
end

# -------- Constraints --------
constraints = @constraints begin
    q(s_0, s_current, u_current, B) = q(s_0, s_current, u_current)q(B)
    q(s_current, s_next, u_next, B) = q(s_current, s_next, u_next)q(B)
    q(s_current, o, A) = q(s_current, o)q(A)
    q(s_next, o_next, A) = q(s_next, o_next)q(A)
end
initialization = @initialization begin
    q(B) = TensorDirichlet(ones(3,3,4))
    q(A) = MatrixDirichlet(diageye(3) * 100 .+ 0.0001)
end

result = infer(model = pomdp_demo(),
               data = (o = [0, 1, 0], u_current = UnfactorizedData([0, 1, 0, 0])),
               constraints = constraints,
               iterations = 10,
               initialization = initialization)

Now there are some additional things I have to explain:

  • I added the previous control to the data. This is because, if we have the previous control and the new observation, we can learn B. Notice that we can use the Transition node both to learn the parameters when all data is available and to perform planning when the data is unavailable, using our learned B! This also means that I distinguish between the current observation (which has been made) and the next observation (which we want to fix at a specific value). A small sketch of reading the planned control out of the result follows this list.
  • I use UnfactorizedData for u_current (documentation can be found at https://reactivebayes.github.io/RxInfer.jl/stable/manuals/constraints-specification/#RxInfer.UnfactorizedData). As of now, RxInfer can only do inference for a generic Transition node if a joint distribution is learnt over all categorical variables attached to that node; without UnfactorizedData, RxInfer would attempt to learn q(u_current)q(s_0, s_current) instead of q(u_current, s_0, s_current), and the Transition node is not programmed to handle that. So that is why that statement is there.
  • I removed the free_energy=true keyword from the inference procedure because there is a bug somewhere in our free energy computation. We are aware of the bugs (Free energy computation of Transition is broken, RxInfer.jl#410) and where they occur (Decomposition of marginals does not decompose namedtuple when other marginals are present, #440), but we haven't found the resources yet to patch them. Just know that we are well aware and that a fix is on the way :)
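
For reference, a minimal sketch of reading the planned control out of the result above, assuming (as in the loop further down in this thread) that result.posteriors[:u_next] holds one Categorical posterior per iteration:

q_u_next = result.posteriors[:u_next][end]  # planned control posterior after the final iteration
planned_action = argmax(q_u_next.p)         # index (1..4) of the most probable next control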

I hope this clarifies the new functionality a bit

@apashea

apashea commented Jan 26, 2025

Thank you, Wouter, for your help and advice! It's super helpful to get this much clarity so concisely.

Noted on moving towards deprecating MatrixDirichlet and fixing free_energy.

Apologies, I may be way off course here, but now I am attempting to:

  • initialize with all uniform priors, aside from keeping index 3 as the desired goal prior over observations
  • learn A, B, and s_0
  • run an action-perception loop, which I'm guessing should be done with the streaming inference approach (you had a tutorial for that, which I understood some months ago, but I didn't take notes so I could jump right back in, and it's kind of information overload right now...). Is my approach below reinitializing the model at each iteration (which I do not wish to do, of course)? In this particular case the model doesn't end up moving towards an action that would realize its expectations (action 3 or 4).

Note: I'm using these little [var] tags to keep track of things, since if I were to teach others concisely, I would want to be able to say "___ is the line (or lines) of code to include if you want the model to learn parameter ___."

using RxInfer;

@model function pomdp_demo(o, u_current)
    # -------- Prior specification --------
    A ~ MatrixDirichlet(ones(3,3))   # Likelihood model prior $P(o_t|s_t)$
    B ~ TensorDirichlet(ones(3,3,4)) # State transition model prior $P(s_{t+1}|s_t,pi)$
    s_0 ~ Categorical(ones(3) ./ 3)  # Prior over initial states $P(s_0)$

    # -------- Model parameter inference --------
    # By supplying `u_current`, we can use the `Transition` node to learn the transition probabilities.
    s_current ~ Transition(s_0, B, u_current) # Transition model $P(s_{t+1}|s_t,pi)$             # [s_current]
    # By supplying `o`, we can use the `Transition` node to learn the likelihood probabilities.
    o ~ Transition(s_current, A)                                                                 # [o]

    # -------- Control variable inference (planning) --------
    u_next ~ Categorical([0.25, 0.25, 0.25, 0.25]) # Prior on a control variable $P(pi_0)$
    
    s_next ~ Transition(s_current, B, u_next) # Predict next state     # [s_next]
    o_next ~ Transition(s_next, A) # Predict next observation          # [o_next]

    o_next ~ Categorical([0.0, 0.0, 1.0]) # Prior over future observations
    
end

# -------- Constraints --------
constraints = @constraints begin
    q(s_0, s_current, u_current, B) = q(s_0, s_current, u_current)q(B)  # Transition model prior learning  # [s_current]
    q(s_current, s_next, u_next, B) = q(s_current, s_next, u_next)q(B)  # Control prior learning           # [s_next]
    q(s_current, o, A) = q(s_current, o)q(A)                            # Likelihood model learning        # [o]
    q(s_next, o_next, A) = q(s_next, o_next)q(A)                        # Predict next observation        # [o_next]
end
initialization = @initialization begin
    q(B) = TensorDirichlet(ones(3,3,4))                                 # Transition model prior learning  # [s_current]
    q(A) = MatrixDirichlet(ones(3,3))                                   # Likelihood model prior learning  # [o]
    q(s_0) = Categorical(ones(3) ./ 3)                                 # Prior over initial states $P(s_0)$
end

# Single inference originally shared:
# result = infer(model = pomdp_demo(),
# data = (o = [0, 1, 0], u_current = UnfactorizedData([0, 1, 0, 0])),
# constraints = constraints,
# iterations = 10,
# initialization = initialization,
# )

# Action-perception loop
action_vec = Vector([0, 1, 0, 0])             # Initial forced action
action_vec_uf = UnfactorizedData(action_vec)  # Required for infer() (?)

for i in 1:10
    # Deterministic environment: u1 -> o1, u2 -> o2, u3 -> o3, u4 -> o3
    if argmax(action_vec) == 1
        o = [1, 0, 0]
    elseif argmax(action_vec) == 2
        o = [0, 1, 0]
    else
        o = [0, 0, 1]
    end
    println("Iteration: $i: Action chosen: = $(argmax(action_vec)), Observation elicited = $(argmax(o))")
    result = infer(model = pomdp_demo(),
                   data = (o = o, u_current = action_vec_uf),
                   constraints = constraints,
                   iterations = 100,
                   initialization = initialization)
    action_vec = Int.(1:length(result.posteriors[:u_next][1].p) .== argmax(result.posteriors[:u_next][100].p))  # get next action
    #println("q(A): ", result.posteriors[:A][10])
    #println("q(B): ", result.posteriors[:B][10])
    #println("q(s_0): ", result.posteriors[:s_0][1])
    #println("q(u_next): ", result.posteriors[:u_next][1].p)
    action_vec_uf = UnfactorizedData(action_vec)
end

@wouterwln
Member Author

There are two very simple reasons why your inference does not converge.

  1. For every datapoint you get, you reset the prior distributions over A and B to the uniform distributions, so the model has no way of learning over different epochs. You can fix this with the following pattern:
@model function pomdp(p_A, .....)
    A ~ p_A
    ...
end
...
p_A = MatrixDirichlet(prior_parameters)
...
for i in 1:10
    ...
    result = infer(model = pomdp(p_A = p_A, ...), ...)
    ...
    p_A = result.posteriors[:A]  # with multiple iterations this is a vector; take the last entry, e.g. last(result.posteriors[:A])
    ...
end
  2. The prior you set for A is uniform, so you don't give the model any incentive to trust its own observations. I think you're better off setting MatrixDirichlet(diageye(3) .+ 1.0) as the prior, or anything like that, so that the model at least learns to trust its own senses. (A small combined sketch follows below.)
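
A minimal sketch of how the action-perception loop above could carry the learned posteriors across epochs, assuming pomdp_demo is extended with hypothetical prior arguments p_A and p_B (following the pattern under point 1) and that the constraints and initialization from the earlier comment are reused:

p_A = MatrixDirichlet(diageye(3) .+ 1.0)  # informative likelihood prior (point 2)
p_B = TensorDirichlet(ones(3, 3, 4))      # transition model prior

action_vec = [0, 1, 0, 0]                 # initial forced action

for i in 1:10
    global p_A, p_B, action_vec  # needed when running this snippet as a script rather than in the REPL

    # Deterministic environment, as in the loop above: u1 -> o1, u2 -> o2, u3/u4 -> o3
    o = argmax(action_vec) == 1 ? [1, 0, 0] :
        argmax(action_vec) == 2 ? [0, 1, 0] : [0, 0, 1]

    result = infer(model = pomdp_demo(p_A = p_A, p_B = p_B),  # p_A, p_B are assumed (hypothetical) model arguments
                   data = (o = o, u_current = UnfactorizedData(action_vec)),
                   constraints = constraints,
                   iterations = 10,
                   initialization = initialization)

    # Carry the learned posteriors over to the next epoch (point 1)
    p_A = last(result.posteriors[:A])
    p_B = last(result.posteriors[:B])

    # Pick the next action from the planned control posterior
    action_vec = Int.(1:4 .== argmax(last(result.posteriors[:u_next]).p))
end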

@LearnableLoopAI

Very pleased to see this. Thank you so much. I'm looking forward to using this as soon as I can.
